"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


This script plots the clustering of the APS dataset shown in Figs. 9 and S5.

"""



import sys
import os
# Make the package one directory above this script importable (e.g. the
# `TemporalNetwork` module), independently of the current working directory.
PACKAGE_PARENT = '..'
_script_path = os.path.expanduser(__file__)
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), _script_path)))
_package_root = os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT))
sys.path.append(_package_root)

import numpy as np

import pandas as pd

from TemporalNetwork import ContTempNetwork


from collections import Counter

import graph_tool.all as gt

import matplotlib.patches as patches
import matplotlib.pyplot as plt

# Colour palette for the pie wedges and legends below. Entry 0 is white and is
# never used for wedges: the colour slices further down start at index 1.
color_list =["#ffffff",
"#4ba706",
"#a2007e",
"#806dcb",
"#5eb275",
"#ca3b01",
"#01a4d6",
"#b77600",
"#a39643",
"#cc6ea9",
"#1e5e39",
"#cb5b5a"]


# Deliberate guard: aborts the script if the whole file is executed at once.
# The file is organised in `#%%` cells, presumably meant to be run one by one
# (e.g. in Spyder) -- run the cells below individually.
raise Exception

#%%

net_name = 'aps_monthly'

datadir = '../paper_data/aps/'

# Author-name disambiguation of the APS dataset, available in the
# Supplementary Material of https://doi.org/10.1126/science.aaf5239
df_authors = pd.read_csv('../data/aps/APS_authors.csv', index_col=0)

# Time window of the analysis (3650 -- presumably days, i.e. ~10 years; TODO confirm).
tau_w = 3650.0
tau_ws = [tau_w]

net_file = os.path.join('../paper_data/aps', 'aps_monthly_lcc_net')

# Only these attributes of the saved temporal network are loaded.
_net_attributes = ['node_to_label_dict',
                   'events_table',
                   'times',
                   'time_grid',
                   'num_nodes',
                   'time_slices_bounds',
                   'time_slices_bounds_datetimes']
net = ContTempNetwork.load(net_file, attributes_list=_net_attributes)

num_nodes, time_slices = net.num_nodes, net.time_slices_bounds

# Inverse of net.node_to_label_dict: author label -> node id.
label_to_node_dict = dict((label, node) for node, label in net.node_to_label_dict.items())


#%% 

# Partitions computed for each time window (indexed by tau_w below).
multi_res = pd.read_pickle(os.path.join('../paper_data/aps', 'multi_complenet_partitions_all_comps_active_nodes_sorted.pickle'))

network = 'complenet'


#%%

# list of journals in which each author published
# Per-author journal table indexed by author label. The column `top_non_prl`
# is used as each author's journal class; authors without a value are
# grouped under 'other'.
df_auth_journal = pd.read_csv('../data/aps/df_author_journal.csv', index_col='author')


def _labels_to_nodes(labels):
    """Return the set of node ids for *labels*, skipping labels not in the network."""
    return {label_to_node_dict[l] for l in labels if l in label_to_node_dict}


# journal -> set of node ids. groupby drops NaN keys, so those authors are
# added separately under 'other'.
class_dict_journ = {journ: _labels_to_nodes(df.index.tolist())
                    for journ, df in df_auth_journal.groupby('top_non_prl')}
class_dict_journ['other'] = _labels_to_nodes(
    df_auth_journal.loc[df_auth_journal.top_non_prl.isnull()].index.tolist())

# node id -> journal. `Series.iteritems()` was removed in pandas 2.0;
# `Series.items()` is the equivalent available in every pandas version.
node_to_journal_dict = {label_to_node_dict[a]: c
                        for a, c in df_auth_journal.top_non_prl.items()
                        if a in label_to_node_dict}

# Per-author country table (from affiliations) indexed by author label.
df_auth_country = pd.read_csv('../data/aps/df_author_country.csv', index_col='author')
# Assign on the frame itself: the former chained assignment
# `df.top_country.loc[mask] = ...` may act on a temporary copy and be lost
# (SettingWithCopy / Copy-on-Write semantics).
df_auth_country['top_country'] = df_auth_country['top_country'].fillna('undefined')

# country -> set of node ids.
class_dict_country = {count: _labels_to_nodes(df.index.tolist())
                      for count, df in df_auth_country.groupby('top_country')}

# node id -> country.
node_to_country_dict = {label_to_node_dict[a]: c
                        for a, c in df_auth_country.top_country.items()
                        if a in label_to_node_dict}

#%% pick initial cluster    
    
#cluster from forw 00s
# Clusters of the first forward partition ("forw 00s") used as roots below.
_forw_00s_clusters = multi_res[tau_w]['partitions_forw'][0].cluster_list


def _describe_cluster(clust):
    """Return one ``(author name, country, journal)`` tuple per node of *clust*."""
    return [(df_authors.iloc[net.node_to_label_dict[n]]['name'],
             node_to_country_dict[n],
             node_to_journal_dict[n]) for n in clust]


names1 = _describe_cluster(_forw_00s_clusters[3])    # UK-Finland-Belgium
names2 = _describe_cluster(_forw_00s_clusters[18])   # USA-Hungary

# The last selected cluster stays "current": orig_clust_id/orig_clust_name/
# orig_clust are read by the single-root graph cells below.
orig_clust_id = 1
orig_clust_name = 'USA-Italy'
orig_clust = _forw_00s_clusters[orig_clust_id]
names3 = _describe_cluster(orig_clust)

# Description of the current cluster (a copy, the old code rebuilt it).
names = list(names3)
print(names)

# Country composition of the current cluster (was computed and discarded).
print(Counter([c for _, c, _ in names]).most_common())

print(pd.DataFrame(names).to_latex(index=False))

#%% print paper table

def _author_table(rows):
    """Tabulate *rows*, title-case the name column (0) and sort by it."""
    table = pd.DataFrame(rows)
    table[0] = table[0].str.title()
    return table.sort_values(0).reset_index(drop=True)


df1, df2, df3 = (_author_table(rows) for rows in (names1, names2, names3))

# The three clusters side by side; shorter columns are padded with blanks.
print(pd.concat([df1, df2, df3], axis=1).to_latex(index=False, na_rep=""))


#%% load prob flow between clusters results

dire = 'forw'

# Probability-flow results precomputed for the current root cluster
# (`orig_clust_name`).
pflow = pd.read_pickle(os.path.join('../paper_data/aps',f'{network}_{dire}_tau_w{tau_w:.3e}_prob_flow_integ_{orig_clust_name}_clust.pickle'))

#%% build a single-root flow graph starting from the selected forward cluster

# Vertices are clusters (weight = cluster size); edges carry flow values.
G = gt.Graph()

edge_weights = G.new_edge_property('double')
node_weights = G.new_vertex_property('double')


v_root = G.add_vertex(1)
node_weights[v_root] = len(orig_clust)
vertex_to_clust_dict = {v_root : orig_clust}

v_roots = [v_root]
#%% top clusts from the 90s

v_90s_list = list(G.add_vertex(len(pflow['clusts_90s_forw'])))

# The i-th 90s cluster gets the i-th largest root->90s flow (clusters are
# presumably stored by decreasing flow -- TODO confirm). Sorted once instead
# of once per iteration.
flow_root_to_90s = np.sort(pflow['T_clust_clust_2010_to_2000'])[::-1]
for i, (clust, v_i) in enumerate(zip(pflow['clusts_90s_forw'], v_90s_list)):
    vertex_to_clust_dict[v_i] = clust
    node_weights[v_i] = len(clust)

    e_i = G.add_edge(v_root, v_i)
    edge_weights[e_i] = flow_root_to_90s[i]


#%% add nodes from the 80s (children of the first 90s vertex)

v_80s_list = list(G.add_vertex(len(pflow['clusts_80s_forw'])))

flow_90s0_to_80s = np.sort(pflow['T_clust_clust_2000_to_1990'][0])[::-1]
for i, (clust, v_i) in enumerate(zip(pflow['clusts_80s_forw'], v_80s_list)):
    vertex_to_clust_dict[v_i] = clust
    node_weights[v_i] = len(clust)

    e_i = G.add_edge(v_90s_list[0], v_i)
    edge_weights[e_i] = flow_90s0_to_80s[i]


#%% add edges between level 90s and level 80s from other nodes

# Remaining 90s vertices (k >= 1) only link to 80s clusters already present.
for k in range(1, len(v_90s_list)):
    print(k)
    flow_k = np.sort(pflow['T_clust_clust_2000_to_1990'][k])[::-1]
    for i, idx in enumerate(pflow['clusts_80s_2000_to_1990_forw_idx'][k]):
        if idx in pflow['clusts_80s_forw_idx']:
            where = np.where(pflow['clusts_80s_forw_idx'] == idx)[0][0]

            e_i = G.add_edge(v_90s_list[k], v_80s_list[where])
            edge_weights[e_i] = flow_k[i]
            print(k, where)


#%% top clusts from 70s (children of the first 80s vertex)

v_70s_list = list(G.add_vertex(len(pflow['clusts_70s_forw'])))

flow_80s0_to_70s = np.sort(pflow['T_clust_clust_1990_to_1980'][0])[::-1]
for i, (clust, v_i) in enumerate(zip(pflow['clusts_70s_forw'], v_70s_list)):
    vertex_to_clust_dict[v_i] = clust
    node_weights[v_i] = len(clust)

    e_i = G.add_edge(v_80s_list[0], v_i)
    edge_weights[e_i] = flow_80s0_to_70s[i]


#%% add edges between level 80s and 70s from other nodes

# Remaining 80s vertices (k >= 1) only link to 70s clusters already present.
for k in range(1, len(v_80s_list)):
    print(k)
    flow_k = np.sort(pflow['T_clust_clust_1990_to_1980'][k])[::-1]
    for i, idx in enumerate(pflow['clusts_70s_1990_to_1980_forw_idx'][k]):
        if idx in pflow['clusts_70s_forw_idx']:
            where = np.where(pflow['clusts_70s_forw_idx'] == idx)[0][0]

            e_i = G.add_edge(v_80s_list[k], v_70s_list[where])
            edge_weights[e_i] = flow_k[i]
            print(k, where)

#%% make combined graph

#%%
#cluster from forw 00s
# The three forward 00s clusters (partition index, name) used as roots of
# the combined flow graph.
orig_clust_id_1, orig_clust_name_1 = 3, 'UK-Finland-Belgium'  # UK/Belgian cluster
orig_clust_id_2, orig_clust_name_2 = 18, 'USA-Hungary'
orig_clust_id_3, orig_clust_name_3 = 1, 'USA-Italy'

_forw_cluster_list = multi_res[tau_w]['partitions_forw'][0].cluster_list
orig_clust_1, orig_clust_2, orig_clust_3 = (
    _forw_cluster_list[cid]
    for cid in (orig_clust_id_1, orig_clust_id_2, orig_clust_id_3))

# Name used in the output file names of the combined graph.
orig_clust_name = "three_roots_UK-Finland-Belgium_and_USA-Hungary_and_USA_Italy"
#%% load prob flow between clusters results

dire = 'forw'


def _pflow_file(clust_name):
    """Return the path of the precomputed flow pickle for root *clust_name*."""
    return os.path.join('../paper_data/aps', f'{network}_{dire}_tau_w{tau_w:.3e}_prob_flow_integ_{clust_name}_clust.pickle')


pflow1, pflow2, pflow3 = (pd.read_pickle(_pflow_file(name))
                          for name in (orig_clust_name_1, orig_clust_name_2, orig_clust_name_3))

#%% build the combined flow graph rooted at the three forward clusters

G = gt.Graph()

# Edge weight: flow value between two clusters; vertex weight: cluster size.
edge_weights = G.new_edge_property('double')
node_weights = G.new_vertex_property('double')


# One root vertex per selected cluster. They are created first, so they
# get vertex indices 0, 1, 2.
v_root_1 = G.add_vertex(1)
node_weights[v_root_1] = len(orig_clust_1)

v_root_2 = G.add_vertex(1)
node_weights[v_root_2] = len(orig_clust_2)

v_root_3 = G.add_vertex(1)
node_weights[v_root_3] = len(orig_clust_3)

# vertex -> set of network nodes forming the cluster it represents.
vertex_to_clust_dict = {v_root_1 : orig_clust_1,
                        v_root_2 : orig_clust_2,
                        v_root_3 : orig_clust_3}

v_roots = [v_root_1, v_root_2, v_root_3]

#%% top clusts from the 90s
    
# The three roots' 90s clusters are not required to be disjoint (the original
# notes an overlap between pflow3 and pflow1). A cluster index shared by
# several roots is represented by one vertex.
clusts_idx_90s = sorted(set(pflow3['clusts_90s_forw_idx'])
                        | set(pflow2['clusts_90s_forw_idx'])
                        | set(pflow1['clusts_90s_forw_idx']))

# One vertex per distinct 90s cluster index, in sorted-index order.
v_90s_list = list(G.add_vertex(len(clusts_idx_90s)))

# Per root: the shared vertices of its 90s clusters, in its own cluster order.
v_90s_list_1, v_90s_list_2, v_90s_list_3 = (
    [v_90s_list[clusts_idx_90s.index(idx)] for idx in pf['clusts_90s_forw_idx']]
    for pf in (pflow1, pflow2, pflow3))

# Link every root to its 90s clusters; the i-th cluster gets the i-th
# largest root->90s flow.
for root_vertex, pf, targets in ((v_root_1, pflow1, v_90s_list_1),
                                 (v_root_2, pflow2, v_90s_list_2),
                                 (v_root_3, pflow3, v_90s_list_3)):
    for i, (clust, v_i) in enumerate(zip(pf['clusts_90s_forw'], targets)):
        vertex_to_clust_dict[v_i] = clust
        node_weights[v_i] = len(clust)

        e_i = G.add_edge(root_vertex, v_i)
        edge_weights[e_i] = np.sort(pf['T_clust_clust_2010_to_2000'])[::-1][i]
        

#%% add nodes from the 80s

# assert not np.isin(pflow1['clusts_80s_forw_idx'][0],pflow2['clusts_80s_forw_idx'][0]).any()
# assert not np.isin(pflow2['clusts_80s_forw_idx'][0],pflow1['clusts_80s_forw_idx'][0]).any()

# Union of the 80s cluster indices reached from the three roots; a shared
# index becomes a single vertex.
clusts_idx_80s = sorted(set(pflow3['clusts_80s_forw_idx'])
                        | set(pflow2['clusts_80s_forw_idx'])
                        | set(pflow1['clusts_80s_forw_idx']))

v_80s_list = list(G.add_vertex(len(clusts_idx_80s)))

# Per root: the shared vertices of its 80s clusters, in its own cluster order.
v_80s_list_1, v_80s_list_2, v_80s_list_3 = (
    [v_80s_list[clusts_idx_80s.index(idx)] for idx in pf['clusts_80s_forw_idx']]
    for pf in (pflow1, pflow2, pflow3))

# Attach the cluster and its size to every 80s vertex (edges are added later).
for pf, targets in ((pflow1, v_80s_list_1), (pflow2, v_80s_list_2), (pflow3, v_80s_list_3)):
    for clust, v_i in zip(pf['clusts_80s_forw'], targets):
        vertex_to_clust_dict[v_i] = clust
        node_weights[v_i] = len(clust)
    
#%% add edges between level 90s and level 80s 


# For every 90s vertex of each root's tree, add an edge to each of its 80s
# target clusters present in any root's 80s list. A target shared by several
# roots produces parallel edges; they are removed later with
# gt.remove_parallel_edges.
target_idx_label = 'clusts_80s_forw_idx'
_targets_80s = ((pflow1, v_80s_list_1), (pflow2, v_80s_list_2), (pflow3, v_80s_list_3))
for v_source_list, pflow_source_idx, pflow_source in zip(
        [v_90s_list_1, v_90s_list_2, v_90s_list_3],
        [pf['clusts_80s_2000_to_1990_forw_idx'] for pf in (pflow1, pflow2, pflow3)],
        [pf['T_clust_clust_2000_to_1990'] for pf in (pflow1, pflow2, pflow3)]):
    for k, v_source in enumerate(v_source_list):
        print(k)
        for i, idx in enumerate(pflow_source_idx[k]):
            for pf_target, v_target_list in _targets_80s:
                if idx in pf_target[target_idx_label]:
                    where = np.where(pf_target[target_idx_label] == idx)[0][0]

                    e_i = G.add_edge(v_source, v_target_list[where])
                    edge_weights[e_i] = np.sort(pflow_source[k])[::-1][i]
                    print(k, where)

            


#%% top clusts from 70s

# Union of the 70s cluster indices reached from the three roots; a shared
# index becomes a single vertex.
clusts_idx_70s = sorted(set(pflow3['clusts_70s_forw_idx'])
                        | set(pflow2['clusts_70s_forw_idx'])
                        | set(pflow1['clusts_70s_forw_idx']))

v_70s_list = list(G.add_vertex(len(clusts_idx_70s)))

# Per root: the shared vertices of its 70s clusters, in its own cluster order.
v_70s_list_1, v_70s_list_2, v_70s_list_3 = (
    [v_70s_list[clusts_idx_70s.index(idx)] for idx in pf['clusts_70s_forw_idx']]
    for pf in (pflow1, pflow2, pflow3))

# Attach the cluster and its size to every 70s vertex (edges are added later).
for pf, targets in ((pflow1, v_70s_list_1), (pflow2, v_70s_list_2), (pflow3, v_70s_list_3)):
    for clust, v_i in zip(pf['clusts_70s_forw'], targets):
        vertex_to_clust_dict[v_i] = clust
        node_weights[v_i] = len(clust)

#%% add edges between level 80s and level 70s

# For every 80s vertex of each root's tree, add an edge to each of its 70s
# target clusters present in any root's 70s list. A target shared by several
# roots produces parallel edges; they are removed later with
# gt.remove_parallel_edges.
target_idx_label = 'clusts_70s_forw_idx'
_targets_70s = ((pflow1, v_70s_list_1), (pflow2, v_70s_list_2), (pflow3, v_70s_list_3))
for v_source_list, pflow_source_idx, pflow_source in zip(
        [v_80s_list_1, v_80s_list_2, v_80s_list_3],
        [pf['clusts_70s_1990_to_1980_forw_idx'] for pf in (pflow1, pflow2, pflow3)],
        [pf['T_clust_clust_1990_to_1980'] for pf in (pflow1, pflow2, pflow3)]):
    for k, v_source in enumerate(v_source_list):
        print(k)
        for i, idx in enumerate(pflow_source_idx[k]):
            for pf_target, v_target_list in _targets_70s:
                if idx in pf_target[target_idx_label]:
                    where = np.where(pf_target[target_idx_label] == idx)[0][0]

                    e_i = G.add_edge(v_source, v_target_list[where])
                    edge_weights[e_i] = np.sort(pflow_source[k])[::-1][i]
                    print(k, where)

#%% 
# The layer-linking loops add one edge per matching pflow result, so a target
# cluster shared by several roots yields duplicate edges; keep one per pair.
gt.remove_parallel_edges(G)

#%% save graph

# Store the weights as internal properties so they are saved with the graph.
G.ep['weights'] = edge_weights
G.vp['weights'] = node_weights

G.save(f'../paper_data/aps/{network}_flow_prob/{dire}_flow_prob_graph_{tau_w:.3e}_{orig_clust_name}_clust.gt')

# Vertex objects are tied to this Graph instance; persist plain integer
# indices instead so the mapping can be rebuilt after reloading.
vertex_idx_to_clust_dict = { G.vertex_index[v] : c for v,c in vertex_to_clust_dict.items()}



# Side-car pickle with the cluster mapping and the vertex indices of each layer.
pd.to_pickle({'edge_weights' : edge_weights,
              'node_weights' : node_weights,
              'vertex_idx_to_clust_dict' : vertex_idx_to_clust_dict,
               'v_idx_root' : [G.vertex_index[v] for v in v_roots],
              'v_90s_idx_list' : [G.vertex_index[v] for v in v_90s_list],
              'v_80s_idx_list' : [G.vertex_index[v] for v in v_80s_list],
              'v_70s_idx_list' : [G.vertex_index[v] for v in v_70s_list]},
             f'../paper_data/aps/{network}_flow_prob/{dire}_flow_prob_graph_{tau_w:.3e}_{orig_clust_name}_clust_graph_attr.pickle')

#%%
# Reload the saved flow graph and its side-car attributes, then rebuild the
# vertex-based lookups from the stored integer vertex indices.
G = gt.load_graph(f'../paper_data/aps/{network}_flow_prob/{dire}_flow_prob_graph_{tau_w:.3e}_{orig_clust_name}_clust.gt')

attr = pd.read_pickle(f'../paper_data/aps/{network}_flow_prob/{dire}_flow_prob_graph_{tau_w:.3e}_{orig_clust_name}_clust_graph_attr.pickle')

edge_weights, node_weights = G.ep['weights'], G.vp['weights']
vertex_to_clust_dict = dict((G.vertex(vid), clust)
                            for vid, clust in attr['vertex_idx_to_clust_dict'].items())
v_90s_list, v_80s_list, v_70s_list = (
    [G.vertex(vid) for vid in attr[key]]
    for key in ('v_90s_idx_list', 'v_80s_idx_list', 'v_70s_idx_list'))
v_roots = [G.vertex(vid) for vid in attr['v_idx_root']]

#%% draw graph





#%% custom layout
from wquantiles import median

# Custom layered layout: every decade is drawn as one horizontal row.
hier_lay = G.new_vertex_property('vector<double>')

#(0,0) is upper-left corner
# Extent of the drawing area, in layout units.
X_max = 20
Y_max = 10

X_mean = X_max/2  # not used in the visible part of the script

num_layers = 4  # roots (2000s), 90s, 80s, 70s
k_in = G.degree_property_map('in')  # not used in the visible part of the script
# Weighted total degree (sum of incident edge weights), used to order rows.
k_tot = G.degree_property_map('total',weight=edge_weights)


# Layer k is placed at y = Y_max - k*Y_max/(num_layers-1): the roots at
# y = Y_max, the 70s at y = 0.

# hier_lay[0] = [0,Y_max]

# first pass
# Sort each row by increasing k_tot, then interleave it (every other vertex
# forward, the remaining ones backward) so that the vertices with the largest
# k_tot end up in the middle of the row. Vertices are evenly spaced in (0, X_max).
for k, vlist in zip(range(num_layers), [v_roots, v_90s_list, v_80s_list, v_70s_list]):
    vlist = sorted(vlist, key=lambda v:k_tot[v])
    vlist = vlist[len(vlist)%2::2] + vlist[-1::-2]
    
    for i,v in enumerate(vlist):
        # hier_lay[v] = [i*X_max/(len(vlist)-1), Y_max - k*Y_max/(num_layers-1)]
        hier_lay[v] = [(i+1)* X_max/(len(vlist)+1), Y_max - k*Y_max/(num_layers-1)]
        
# now try to minimize edge crossing:
def _weighted_mean_x(edges, endpoint):
    """Return the edge-weighted mean x-coordinate (``hier_lay``) of ``endpoint(e)``.

    Weights are read from the module-level ``edge_weights`` map. The total
    weight is computed once (the former code re-summed it for every edge,
    O(k^2)); each term is still ``x * w / total`` and an empty edge list
    still yields 0.0.
    """
    edges = list(edges)
    weights = [edge_weights[e] for e in edges]
    total = sum(weights)
    return np.array([hier_lay[endpoint(e)][0] * w / total
                     for e, w in zip(edges, weights)]).sum()


def mean_pos_in(v):
    """Return the weighted mean x-position of the in-neighbours of *v*."""
    return _weighted_mean_x(v.in_edges(), lambda e: e.source())


def median_pos_in(v):
    """Return the weighted median x-position of the in-neighbours of *v*.

    Uses ``median`` (data, weights) imported from ``wquantiles`` at file level.
    """
    return median(np.array([hier_lay[e.source()][0] for e in v.in_edges()]), np.array([edge_weights[e] for e in v.in_edges()]))


def mean_pos_out(v):
    """Return the weighted mean x-position of the out-neighbours of *v*."""
    return _weighted_mean_x(v.out_edges(), lambda e: e.target())


def median_pos_out(v):
    """Return the weighted median x-position of the out-neighbours of *v*.

    Uses ``median`` (data, weights) imported from ``wquantiles`` at file level.
    """
    return median(np.array([hier_lay[e.target()][0] for e in v.out_edges()]), np.array([edge_weights[e] for e in v.out_edges()]))
    


# go from 90s to 70s using pos from below
# sort vroots

# Re-order the roots like in the first pass. NOTE(review): root positions are
# not recomputed below, so this only changes the order of the `v_roots` list.
v_roots = sorted(v_roots, key=lambda v:k_tot[v])
v_roots = v_roots[len(v_roots)%2::2] + v_roots[-1::-2]

# Sweep 90s -> 80s -> 70s: order each row by the weighted median x of its
# in-neighbours (the previous, already placed, layer), ties broken by k_tot.
for k, vlist in zip(range(1,num_layers), [v_90s_list, v_80s_list, v_70s_list]):
    # keys are evaluated before any vertex of this row is moved
    med_pos = {v:median_pos_in(v) for v in vlist}
    # k_tot_pos = [k_tot[v] for v in vlist]
    
    vlist = sorted(vlist, key=lambda v: (med_pos[v], k_tot[v]))
    
    for i,v in enumerate(vlist):
        # hier_lay[v] = [i*X_max/(len(vlist)-1), Y_max - k*Y_max/(num_layers-1)]
        hier_lay[v] = [(i+1)* X_max/(len(vlist)+1), Y_max - k*Y_max/(num_layers-1)]

# try to optimize the middle layers
# Only the 90s and 80s rows are moved here; roots and the 70s stay fixed.
for _ in range(10): # number of iterations


    # go from 80s to 90s using pos from above
    # (order by weighted median x of the out-neighbours, i.e. the older
    # layer; ties broken by the current x)
    for k, vlist in zip(range(2,0,-1), [v_80s_list, v_90s_list]):
        
        med_pos = {v:median_pos_out(v) for v in vlist}
        
        
        vlist = sorted(vlist, key=lambda v: (med_pos[v], hier_lay[v][0]))
        
        
        for i,v in enumerate(vlist):
            # hier_lay[v] = [i*X_max/(len(vlist)-1), Y_max - k*Y_max/(num_layers-1)]
            hier_lay[v] = [(i+1)* X_max/(len(vlist)+1), Y_max - k*Y_max/(num_layers-1)]
                     
    # now from 90s to 80s
    # (order by weighted median x of the in-neighbours; ties by current x)
    for k, vlist in zip(range(1,3), [v_90s_list, v_80s_list]):
        
        med_pos = {v:median_pos_in(v) for v in vlist}
    
        vlist = sorted(vlist, key=lambda v: (med_pos[v], hier_lay[v][0]))
        
        for i,v in enumerate(vlist):
            # hier_lay[v] = [i*X_max/(len(vlist)-1), Y_max - k*Y_max/(num_layers-1)]
            hier_lay[v] = [(i+1)* X_max/(len(vlist)+1), Y_max - k*Y_max/(num_layers-1)]
#%% add text

# Vertex labels: node weight (cluster size) printed without decimals.
vertex_text = G.new_vertex_property('string')
for vertex in G.vertices():
    vertex_text[vertex] = f'{node_weights[vertex]:.0f}'

# Edge labels: flow value with two decimals (only drawn if passed to graph_draw).
edge_text = G.new_edge_property('string')
for edge in G.edges():
    edge_text[edge] = f'{edge_weights[edge]:.2f}'
    
#%% make country pie fractions

# Country composition of every cluster vertex.
vertex_country_counters = {v: Counter(node_to_country_dict[n] for n in c)
                           for v, c in vertex_to_clust_dict.items()}

# Country counts over all authors present in any cluster of the graph.
all_nodes_country_counter = Counter(node_to_country_dict[n]
                                    for n in set.union(*vertex_to_clust_dict.values()))

top_countries = [c for c, _ in all_nodes_country_counter.most_common(5)]

# Countries highlighted in addition to the top 5, per figure. The former
# "two_roots_'UK-..." key contained a stray quote; the corrected key is added
# and the old spelling is kept for compatibility.
_extra_countries = {
    'USA-Hungary': ['Hungary'],
    'UK-Finland-Belgium': ['Finland', 'Spain'],
    'two_roots_UK-Finland-Belgium_and_USA-Hungary': ['Finland', 'Spain', 'Hungary'],
    "two_roots_'UK-Finland-Belgium_and_USA-Hungary": ['Finland', 'Spain', 'Hungary'],
    'USA-Italy': ['Spain'],
    'three_roots_UK-Finland-Belgium_and_USA-Hungary_and_USA_Italy': ['Italy', 'Hungary', 'Finland', 'Spain'],
}
# Skip countries already in the top 5: a duplicate entry would be counted
# twice in the pie fractions and make the "rest" wedge negative.
for country in _extra_countries.get(orig_clust_name, []):
    if country not in top_countries:
        top_countries.append(country)


# Pie fractions: one wedge per top country plus a final "rest" wedge.
# Counter returns 0 for absent countries.
pie_fracs = G.new_vertex_property("vector<double>")
for v, clust in vertex_to_clust_dict.items():
    top_country_counts = [vertex_country_counters[v][country] for country in top_countries]
    top_country_counts.append(len(clust) - sum(top_country_counts))
    pie_fracs[v] = np.array(top_country_counts)/len(clust)

country_color_list = color_list.copy()
country_color_list[8] = country_color_list[10] # replace brown by dark green
country_color_list = country_color_list[1:1+len(top_countries)]  # skip white
country_color_list += ["#808080"] # grey for "rest" category

#%%
# One extra vertex per existing vertex, drawn on top of each pie as a smaller
# white circle so the pie looks like a donut.
donut_holes = list(G.add_vertex(G.num_vertices()))
#%%
vertex_shape = G.new_vertex_property('string')
vertex_fill_color = G.new_vertex_property('vector<double>')
vertex_size = G.new_vertex_property('double')

# node_weights is not modified in the loop, so scale it once (the former code
# recomputed prop_to_size over the whole graph twice per vertex).
scaled_sizes = gt.prop_to_size(node_weights, mi=40, ma=100)

# Pie vertex v and its hole d share a position; the hole is 60% as large.
for v, d in enumerate(donut_holes):
    hier_lay[d] = hier_lay[v]
    vertex_size[v] = scaled_sizes[v]
    vertex_size[d] = scaled_sizes[v]*0.6
    vertex_shape[v] = 'pie'
    vertex_shape[d] = 'circle'
    vertex_fill_color[d] = [1,1,1,1]
    vertex_text[d] = vertex_text[v]

#%% draw country graph

# Draw only edges carrying more than 0.05 of flow. The filter is a boolean
# edge property map filled edge by edge, so it is aligned with the edges by
# construction. A plain list built in G.edges() iteration order may be
# matched to edges by index instead.
_edge_keep = G.new_edge_property('bool')
for _e in G.edges():
    _edge_keep[_e] = edge_weights[_e] > 0.05
G_filt = gt.GraphView(G, efilt=_edge_keep)

gt.graph_draw(G_filt, pos=hier_lay,
              vertex_size=vertex_size,
              vertex_pen_width=0,
              edge_pen_width=gt.prop_to_size(edge_weights,mi=2,ma=20,power=1),
              nodesfirst=True,
              vertex_shape=vertex_shape,
              vertex_fill_color=vertex_fill_color,
              vertex_pie_fractions=pie_fracs,
              vertex_rotation=-1*np.pi/2,
              vertex_text_rotation=np.pi/2,
              vertex_text_color=[0,0,0,1],
              vertex_pie_colors=country_color_list,
              output_size=(1500,900),
              vertex_text=vertex_text,
              # edge_text=edge_text,
              edge_color=[0.179, 0.203,0.210, 0.4],
              # fit_view=(0, 1, 1, 12) ,
              vertex_font_family='Helvetica',
              edge_font_family='Helvetica',
              # output=f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_country.pdf',
              # output=f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_country.svg',
              )


#%% journal pies
# Journal composition of every cluster vertex.
vertex_journal_counters = {v: Counter(node_to_journal_dict[n] for n in c)
                           for v, c in vertex_to_clust_dict.items()}

# Journal counts over all authors present in any cluster (kept for inspection).
all_nodes_journal_counter = Counter(node_to_journal_dict[n]
                                    for n in set.union(*vertex_to_clust_dict.values()))

# Hand-picked journal list. A data-driven most_common(10) list used to be
# computed here and was immediately overwritten.
top_journals = ['PhysRevB',
 'PhysRevC',
 'PhysRevD',
 'PhysRevA',
 'PhysRevE',]

# Pie fractions for every cluster vertex (one wedge per journal + "rest").
# Iterating the cluster mapping avoids relying on the donut-hole vertices
# being exactly the second half of the vertex indices.
pie_fracs_journals = G.new_vertex_property("vector<double>")
for v, clust in vertex_to_clust_dict.items():
    top_journal_counts = [vertex_journal_counters[v][journal] for journal in top_journals]
    top_journal_counts.append(len(clust) - sum(top_journal_counts))
    pie_fracs_journals[v] = np.array(top_journal_counts)/len(clust)



journal_color_list = color_list[1:1+len(top_journals)]  # skip white
journal_color_list += ["#808080"] # grey for "rest" category
#%%
# Draw only edges carrying more than 0.05 of flow. The filter is a boolean
# edge property map filled edge by edge, so it is aligned with the edges by
# construction. A plain list built in G.edges() iteration order may be
# matched to edges by index instead.
_edge_keep = G.new_edge_property('bool')
for _e in G.edges():
    _edge_keep[_e] = edge_weights[_e] > 0.05
G_filt = gt.GraphView(G, efilt=_edge_keep)

gt.graph_draw(G_filt, pos=hier_lay,
              vertex_size=vertex_size,
              vertex_pen_width=0,
              edge_pen_width=gt.prop_to_size(edge_weights,mi=1,ma=20,power=1),
              nodesfirst=True,
              vertex_shape=vertex_shape,
              vertex_fill_color=vertex_fill_color,
              vertex_pie_fractions=pie_fracs_journals,
              vertex_pie_colors=journal_color_list,
              output_size=(1500,900),
              vertex_text=vertex_text,
              vertex_text_position=-0.1,
              # edge_text=edge_text,
              edge_color=[0.179, 0.203,0.210, 0.4],
              # fit_view=(1, 1, 12, 12) ,
              # fit_view=(0, 1, 1, 12) ,
              vertex_font_family='Helvetica',
              edge_font_family='Helvetica',
              # output=f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_journals.pdf',
              # output=f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_journals.svg',
              )


#%% draw legend

fig, ax = plt.subplots(figsize=(13,2.5))

# Country legend: one 10x10 swatch per entry, 15 units apart, labelled
# through x ticks placed at the swatch centres.
legend_items = list(zip(top_countries + ['others'], country_color_list))
y = 0
xticks = [15 * slot + 5 for slot in range(len(legend_items))]
xticklabels = [label for label, _ in legend_items]
for x_tick, (_, color) in zip(xticks, legend_items):
    rect = patches.Rectangle((x_tick - 5, y), 10, 10, linewidth=0,
                             facecolor=color)
    ax.add_patch(rect)
x = 15 * len(legend_items)

ax.axis('on')
ax.set_xlim([-5, x + 15])
ax.set_ylim([-5, 15])
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels, font='FreeSans')
ax.set_aspect('equal')


# plt.savefig(f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_legend_countries.pdf')
# plt.savefig(f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_legend_countries.png')
# plt.savefig(f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_legend_countries.svg')

#%%
fig, ax = plt.subplots(figsize=(13,2.5))

# Journal legend: one 10x10 swatch per entry, 15 units apart, labelled
# through x ticks placed at the swatch centres.
legend_items = list(zip(top_journals + ['others'], journal_color_list))
y = 0
xticks = [15 * slot + 5 for slot in range(len(legend_items))]
xticklabels = [label for label, _ in legend_items]
for x_tick, (_, color) in zip(xticks, legend_items):
    rect = patches.Rectangle((x_tick - 5, y), 10, 10, linewidth=0,
                             facecolor=color)
    ax.add_patch(rect)
x = 15 * len(legend_items)

ax.set_xlim([-5, x + 15])
ax.set_ylim([-5, 15])
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels, font='FreeSans')
ax.set_aspect('equal')
        
# plt.savefig(f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_legend_journals.pdf')
# plt.savefig(f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_legend_journals.png')
# plt.savefig(f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_legend_journals.svg')


#%% cheack identity of authors

# For each cluster vertex index (first half of the vertices; the second half
# are donut holes): (node, author name, country, journal) of every member.
clust_names = {
    vid: [(n, df_authors.iloc[net.node_to_label_dict[n]]['name'],
           node_to_country_dict[n],
           node_to_journal_dict[n]) for n in vertex_to_clust_dict[vid]]
    for vid in range(int(G.num_vertices()/2))
}

Counter([c for _, _, c, _ in clust_names[0]]).most_common()

# example from clust 1 https://journals.aps.org/pre/abstract/10.1103/PhysRevE.52.5166

#%% get article titles in clusts

# DOI table with publication dates; the loop below reads its `doi`, `date`
# and `file` (path to the APS JSON metadata) columns.
df_doi_dates = pd.read_csv('../data/aps/df_doi_date.csv.gz', index_col=0,
                            parse_dates=['date'])


#%% count uni- and bigrams in the article titles of every cluster
import tqdm 
import json
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter
from nltk.corpus import stopwords
from datetime import datetime


stopWords = set(stopwords.words('english'))

# punctuation tokens returned by word_tokenize that should not count as words
stopWords.update([',',':','.','(',')'])


counters_bi = {}
counters_uni = {}

# Map singular forms onto their plural so that e.g. "network" and "networks"
# are counted as the same word.
custom_dict = {'network' : 'networks',
               'transition' : 'transitions',
               'section' : 'sections',
               'well' : 'wells',
               'pulse' : 'pulses',
               'array' : 'arrays',
               'crystal' : 'crystals',
               'hole' : 'holes',
               'process' : 'processes',
               'polymer' : 'polymers',
               'resonance' : 'resonances'}

# Publication window [start, end) used for the clusters of each hierarchy
# level; checked in this order, first match wins.
decade_windows = [(v_roots, datetime(2000,1,1), datetime(2010,1,1)),
                  (v_90s_list, datetime(1990,1,1), datetime(2000,1,1)),
                  (v_80s_list, datetime(1980,1,1), datetime(1990,1,1)),
                  (v_70s_list, datetime(1970,1,1), datetime(1980,1,1))]

for clust_id in vertex_to_clust_dict.keys():
    print(clust_id)

    for clust_list, date_start, date_end in decade_windows:
        if clust_id in clust_list:
            break
    else:
        # fail loudly instead of reusing the window of the previous cluster
        raise ValueError(f'cluster {clust_id} is not in any of the decade lists')

    # DOIs of the events whose source AND target nodes are both in the cluster
    members = vertex_to_clust_dict[clust_id]
    clust_dois = set(net.events_table.loc[(net.events_table.source_nodes.isin(members)) &
                                          (net.events_table.target_nodes.isin(members))].doi)

    # half-open window: an article dated exactly on a decade boundary
    # (e.g. a 1 January issue) belongs to one window only
    df_dois = df_doi_dates.loc[(df_doi_dates.date >= date_start) &
                               (df_doi_dates.date < date_end) &
                               (df_doi_dates.doi.isin(clust_dois))]

    # read titles from the APS metadata (JSON is UTF-8 regardless of locale)
    doi_title_list = []
    for i,row in tqdm.tqdm(df_dois.iterrows(), total=df_dois.shape[0]):
        with open(row.file, 'r', encoding='utf-8') as fopen:
            d = json.load(fopen)
        if d['articleType'] != 'article':
            raise ValueError(f"{row.doi}: unexpected articleType {d['articleType']!r}")

        doi_title_list.append({'title': d['title']['value'],
                               'doi' : row.doi})

    # explicit columns so that a cluster without articles gives an empty
    # frame that still has a `title` column
    df_doi_title = pd.DataFrame(doi_title_list, columns=['title', 'doi'])

    # remove html tags; regex=True is required since pandas 2.0 treats the
    # pattern as a literal string by default
    df_doi_title['title'] = df_doi_title.title.str.replace("<[^>]*>", "", regex=True)

    bigram_list = []
    unigram_list = []
    for title in tqdm.tqdm(df_doi_title.title, total=df_doi_title.shape[0]):
        # normalise singular forms to plural (see custom_dict)
        words = [custom_dict.get(w, w) for w in word_tokenize(title.lower())]

        bigram_list.extend((b1, b2) for b1, b2 in nltk.bigrams(words)
                           if b1 not in stopWords and b2 not in stopWords)
        unigram_list.extend(w for w in words if w not in stopWords)

    counters_bi[clust_id] = Counter(bigram_list)
    counters_uni[clust_id] = Counter(unigram_list)



#%% plot bigrams

fig, ax = plt.subplots(1,1,figsize=(23,5))

# Write the 10 most common bigrams (with their counts) of each cluster at the
# cluster's position in the hierarchical layout.
for vlist in (v_roots, v_90s_list, v_80s_list, v_70s_list):
    for v in vlist:
        x_pos, y_pos = hier_lay[v][0], hier_lay[v][1]
        top_bigrams = counters_bi[v].most_common(10)
        label = '\n'.join(f'{first} {second} {count}'
                          for (first, second), count in top_bigrams)
        t = ax.text(x_pos, y_pos, label, fontsize=3, font='FreeSans')

ax.set_xlim([0,X_max])
# inverted y axis (top of the hierarchy at the top)
ax.set_ylim([Y_max,-3])

ax.axis('off')

# Uncomment to export.
# plt.savefig(f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_bigrams_bothinclust.pdf')
# plt.savefig(f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_bigrams_bothinclust.png')
# plt.savefig(f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_bigrams_bothinclust.svg')

#%% plot unigrams

fig, ax = plt.subplots(1,1,figsize=(23,5))

# Write the 10 most common title words (with their counts) of each cluster at
# the cluster's position in the hierarchical layout.
for vlist in (v_roots, v_90s_list, v_80s_list, v_70s_list):
    for v in vlist:
        x_pos, y_pos = hier_lay[v][0], hier_lay[v][1]
        top_words = counters_uni[v].most_common(10)
        label = '\n'.join(f'{word} {count}' for word, count in top_words)
        t = ax.text(x_pos, y_pos, label, fontsize=3, font='FreeSans')

ax.set_xlim([0,X_max])
# inverted y axis (top of the hierarchy at the top)
ax.set_ylim([Y_max,-3])

ax.axis('off')

# Uncomment to export.
# plt.savefig(f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_unigrams_bothinclust.pdf')
# plt.savefig(f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_unigrams_bothinclust.png')
# plt.savefig(f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_unigrams_bothinclust.svg')

#%% draw graph with bigrams
import cairo

vertex_text_position = G.new_vertex_property('double')
vertex_font_size = G.new_vertex_property('double')
vertex_text_offset = G.new_vertex_property('vector<double>')
vertex_font_weight = G.new_vertex_property('object')

# Hand-picked labels for clusters whose most common bigram is tied with
# others (Counter.most_common would just return the first one encountered).
vtext_special = {1 : 'small-world\nnetworks',
                 15 : 'van der\nwaals',
                 13 : 'ladder\npolymers',
                 20:  'x-ray\ndiffraction',
                 28: 'electronic\nstructure'}

# Vertices 0..n_clust-1 are the clusters; vertex v + n_clust is styled
# separately (presumably a companion vertex of cluster v -- confirm).
n_clust = G.num_vertices() // 2

for v in range(n_clust):
    v_twin = v + n_clust

    vertex_text_position[v] = np.pi/2
    vertex_text_position[v_twin] = -0.1

    # vertex_text is the string vertex property created earlier in the script
    if v in vtext_special:
        vertex_text[v] = vtext_special[v]
    else:
        top_bigram = counters_bi[v].most_common(1)
        # a cluster without any bigram gets an empty label instead of an IndexError
        vertex_text[v] = '\n'.join(top_bigram[0][0]) if top_bigram else ''

    vertex_text_offset[v] = [0.0,0.0]
    vertex_text_offset[v_twin] = [0.0,0.0]

    vertex_font_size[v] = 17
    vertex_font_size[v_twin] = 15

    vertex_font_weight[v] = cairo.FONT_WEIGHT_BOLD
    vertex_font_weight[v_twin] = cairo.FONT_WEIGHT_NORMAL

#%% draw the cluster graph, keeping only edges heavier than a threshold

edge_weight_threshold = 0.05

# Build the edge filter per edge as a property map, so the mask does not rely
# on G.edges() iterating in edge-index order (not guaranteed after edge removals).
edge_keep = G.new_edge_property('bool')
for e in G.edges():
    edge_keep[e] = edge_weights[e] > edge_weight_threshold

G_filt = gt.GraphView(G, efilt=edge_keep)

gt.graph_draw(G_filt, pos=hier_lay,
              vertex_size=vertex_size,
              vertex_pen_width=0,
              edge_pen_width=gt.prop_to_size(edge_weights,mi=1,ma=20,power=1),
              nodesfirst=False,
              vertex_shape=vertex_shape,
              vertex_fill_color=vertex_fill_color,
              vertex_pie_fractions=pie_fracs_journals,
              vertex_pie_colors=journal_color_list,
              output_size=(1500,900),
              vertex_text=vertex_text,
              vertex_text_position=vertex_text_position,
              vertex_text_offset=vertex_text_offset,
              vertex_font_size=vertex_font_size,
              vertex_font_weight=vertex_font_weight,
              # edge_text=edge_text,
              edge_color=[0.179, 0.203,0.210, 0.4],
              # fit_view=(1, 1, 12, 12) ,
              # fit_view=(0, 1, 1, 12) ,
              vertex_font_family='Helvetica',
              edge_font_family='Helvetica',
              # output=f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_journals_bigram.pdf',
              # output=f'../figures/aps/complenet/complenet_forw_flow_prob_{orig_clust_name}_journals_bigram.svg',
              )

